/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_MATRIX_BATCH_MATMUL_CUSTOM_IMPL_H
#define EXAMPLES_MATRIX_BATCH_MATMUL_CUSTOM_IMPL_H
#include "kernel_operator.h"
#include "lib/matmul_intf.h"

// Number of AI cores the batch axis is split across; Process() divides the batch work by this.
constexpr int USED_CORE_NUM = 2;

/**
 * Batch matrix-multiply kernel: C = A x B^T (+ bias), with the batch axis split
 * across USED_CORE_NUM cores. Template parameters describe the tensor types and
 * layouts of the A/B inputs, C output, and bias (each provides ::T and ::layout).
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
class BatchMatmulKernel {
    public:
        __aicore__ inline BatchMatmulKernel(){};
        // Binds the GM addresses to global tensors and applies the per-core batch offset.
        __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling);
        // Runs the batched matmul loop; hasBias selects whether SetBias is called per iteration.
        template <bool hasBias = false>
        __aicore__ inline void Process(AscendC::TPipe* pipe, int32_t batchA, int32_t batchB);
        // High-level matmul object driving the cube unit; configured in Process().
        AscendC::Matmul<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> matmulObj;
    private:
        // Computes each tensor's starting element offset for this core (BSNGD layout only).
        __aicore__ inline void CalcOffset(int32_t blockIdx, const TCubeTiling& tiling, int32_t& offsetA, int32_t& offsetB,
                                          int32_t& offsetC, int32_t& offsetBias);
        using aType = typename A_TYPE::T;
        using bType = typename B_TYPE::T;
        using cType = typename C_TYPE::T;
        using biasType = typename BIAS_TYPE::T;
        AscendC::GlobalTensor<aType> aGlobal;
        AscendC::GlobalTensor<bType> bGlobal;
        AscendC::GlobalTensor<cType> cGlobal;
        AscendC::GlobalTensor<biasType> biasGlobal;
        TCubeTiling tiling;  // copy of the tiling passed to Init(), reused by Process()
};

/**
 * Binds the GM addresses of A/B/bias/C to global tensors, then advances each
 * tensor by this core's batch offset so Process() indexes core-local data.
 * Sizes are computed in bytes from the layout dimensions in the tiling.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias,
                                         GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling)
{
    // Keep a copy of the tiling for later use by Process()/CalcOffset().
    this->tiling = tiling;
    int32_t sizeA = tiling.ALayoutInfoB * tiling.ALayoutInfoS * tiling.ALayoutInfoN * tiling.ALayoutInfoG * tiling.ALayoutInfoD * sizeof(aType);
    int32_t sizeB = tiling.BLayoutInfoB * tiling.BLayoutInfoS * tiling.BLayoutInfoN * tiling.BLayoutInfoG * tiling.BLayoutInfoD * sizeof(bType);
    int32_t sizeC = tiling.CLayoutInfoB * tiling.CLayoutInfoS1 * tiling.CLayoutInfoN * tiling.CLayoutInfoG * tiling.CLayoutInfoS2 * sizeof(cType);
    // Bugfix: the bias buffer holds biasType elements, so its size must use
    // sizeof(biasType), not sizeof(cType) (the two types may differ in width).
    int32_t sizeBias = tiling.CLayoutInfoB * tiling.CLayoutInfoN * tiling.CLayoutInfoG * tiling.CLayoutInfoS2 * sizeof(biasType);

    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), sizeA);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), sizeB);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), sizeC);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), sizeBias);

    // Shift every tensor to this core's slice of the batch axis.
    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    CalcOffset(AscendC::GetBlockIdx(), tiling, offsetA, offsetB, offsetC, offsetBias);
    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];
    // The system workspace must have been set up by the host; bail out if absent.
    if (GetSysWorkSpacePtr() == nullptr) {
        return;
    }
}

/**
 * Executes the batched matmul loop for this core: for each batch group, points
 * the matmul object at the A/B (and optionally bias) slices and iterates the
 * batch into the matching C slice. A is used as-is; B is consumed transposed.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
template <bool hasBias>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Process(AscendC::TPipe* pipe, int32_t batchA, int32_t batchB)
{
    // The output batch count is the larger of the two input batch counts.
    const int batchC = (batchA > batchB) ? batchA : batchB;
    const int gLayout = (tiling.ALayoutInfoG > tiling.BLayoutInfoG) ? tiling.ALayoutInfoG : tiling.BLayoutInfoG;
    // Batch axis is cut across USED_CORE_NUM cores; this core runs loopCount groups.
    const int loopCount = (tiling.ALayoutInfoB / USED_CORE_NUM) * tiling.ALayoutInfoN * gLayout / tiling.BatchNum;
    for (int idx = 0; idx < loopCount; ++idx) {
        // Per-group element offsets; the extra S factor applies when one
        // BatchNum covers a whole N*G plane of the layout.
        int aOffset = idx * tiling.ALayoutInfoD * batchA;
        if (tiling.BatchNum == tiling.ALayoutInfoN * tiling.ALayoutInfoG) {
            aOffset = idx * tiling.ALayoutInfoD * tiling.ALayoutInfoS * batchA;
        }
        int bOffset = idx * tiling.BLayoutInfoD * batchB;
        if (tiling.BatchNum == tiling.BLayoutInfoN * tiling.BLayoutInfoG) {
            bOffset = idx * tiling.BLayoutInfoD * tiling.BLayoutInfoS * batchB;
        }
        matmulObj.SetTensorA(aGlobal[aOffset], false);
        matmulObj.SetTensorB(bGlobal[bOffset], true); // B transpose

        const int cBatchIdx = idx * batchC;
        if constexpr (hasBias) {
            matmulObj.SetBias(biasGlobal[cBatchIdx * tiling.CLayoutInfoS2]);
        }

        int cOffset = cBatchIdx * tiling.CLayoutInfoS2;
        if (tiling.BatchNum == tiling.CLayoutInfoN * tiling.CLayoutInfoG) {
            cOffset *= tiling.CLayoutInfoS1;
        }
        matmulObj.IterateBatch(cGlobal[cOffset], batchA, batchB, false);
        // Serialize iterations: wait for all pipes before reusing the matmul object.
        AscendC::PipeBarrier<PIPE_ALL>();
    }
}

/**
 * Computes this core's starting element offset into each tensor for the
 * multi-core batch split. Offsets are only produced for the BSNGD layout;
 * tensors with other layouts keep the caller-provided values (zero from Init).
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::CalcOffset(int32_t blockIdx, const TCubeTiling& param,
                                        int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    // One batch group per core after the split; common factor for every offset.
    const int perCoreBatch = 1;
    const int32_t base = blockIdx * perCoreBatch;

    if constexpr (A_TYPE::layout == LayoutMode::BSNGD) {
        // Skip the A batches handled by lower-indexed cores.
        offsetA = base * param.ALayoutInfoS * param.ALayoutInfoN * param.ALayoutInfoG * param.ALayoutInfoD;
    }

    if constexpr (B_TYPE::layout == LayoutMode::BSNGD) {
        // Skip the B batches handled by lower-indexed cores.
        offsetB = base * param.BLayoutInfoS * param.BLayoutInfoN * param.BLayoutInfoG * param.BLayoutInfoD;
    }

    if constexpr (C_TYPE::layout == LayoutMode::BSNGD) {
        // C and bias share the batch split; bias has no S1 dimension.
        offsetC = base * param.CLayoutInfoS1 * param.CLayoutInfoN * param.CLayoutInfoG * param.CLayoutInfoS2;
        offsetBias = base * param.CLayoutInfoN * param.CLayoutInfoG * param.CLayoutInfoS2;
    }
}
#endif // EXAMPLES_MATRIX_BATCH_MATMUL_CUSTOM_IMPL_H